from datasets import load_from_disk
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Expose one split of an on-disk dataset as ``(text, label)`` pairs.

    The data is read with ``load_from_disk`` and then indexed by split name,
    so the saved object is assumed to map split names to datasets
    (presumably a ``DatasetDict``) -- confirm against how the data was saved.
    """

    # Split names accepted by the constructor.
    VALID_SPLITS = ("train", "test", "validation")

    def __init__(self, split, data_dir=r"D:/"):
        """Load the requested split from disk.

        Args:
            split: One of ``"train"``, ``"test"`` or ``"validation"``.
            data_dir: Directory handed to ``load_from_disk``. Defaults to the
                location this class has always read from.

        Raises:
            ValueError: If ``split`` is not one of ``VALID_SPLITS``.
        """
        # Validate before touching the disk so a bad name fails fast and
        # never leaves a half-initialised instance behind.
        if split not in self.VALID_SPLITS:
            raise ValueError("数据名错误！")
        super().__init__()
        self.dataset = load_from_disk(data_dir)[split]

    def __len__(self) -> int:
        """Return the number of rows in the selected split."""
        return len(self.dataset)

    def __getitem__(self, item) -> tuple:
        """Return the ``(text, label)`` pair stored at ``item``."""
        # Fetch the row once instead of indexing the dataset per field.
        row = self.dataset[item]
        return row["text"], row["label"]

if __name__ == "__main__":
    # Smoke test: build the "test" split and print each (text, label) pair.
    samples = MyDataset("test")
    for sample in samples:
        print(sample)